# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Standard library imports
from typing import Tuple

# Third-party imports
import mxnet as mx

# First-party imports
from gluonts.core.component import DType, validated
from gluonts.model.common import Tensor
from gluonts.mx.distribution.distribution import softplus
from gluonts.mx.kernels import KernelOutputDict

from .gaussian_process import GaussianProcess


class GaussianProcessNetworkBase(mx.gluon.HybridBlock):
    """
    Defines a Gluon block used for GP training and predictions.

    This base class stores the shared configuration and a per-series
    embedding from which the GP hyper-parameters (kernel arguments plus the
    noise sigma) are derived. The subclasses GaussianProcessTrainingNetwork
    and GaussianProcessPredictionNetwork define how to compute the loss and
    how to generate predictions, respectively.
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        context_length: int,
        cardinality: int,
        kernel_output: KernelOutputDict,
        params_scaling: bool,
        ctx: mx.Context,
        float_type: DType,
        max_iter_jitter: int,
        jitter_method: str,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        prediction_length
            Prediction length.
        context_length
            Length of the conditioning (training) window.
        cardinality
            Number of time series; used as the input dimension of the
            hyper-parameter embedding, so it bounds the valid values of
            ``feat_static_cat``.
        kernel_output
            KernelOutput instance that determines which kernel subclass is
            instantiated and how its arguments are projected.
        params_scaling
            Determines whether or not to scale the model parameters by
            ``kernel_output.gp_params_scaling``.
        ctx
            MXNet context (CPU or GPU) on which to compute.
        float_type
            Determines whether to use single or double precision.
        max_iter_jitter
            Maximum number of iterations for jitter to iteratively make the
            matrix positive definite.
        jitter_method
            Iterative jitter method or eigenvalue decomposition, depending on
            problem size.
        **kwargs
            Arbitrary keyword arguments forwarded to ``HybridBlock``.
        """
        super().__init__(**kwargs)

        self.prediction_length = prediction_length
        self.context_length = context_length
        self.cardinality = cardinality
        self.kernel_output = kernel_output
        self.params_scaling = params_scaling
        self.float_type = float_type
        self.ctx = ctx
        self.max_iter_jitter = max_iter_jitter
        self.jitter_method = jitter_method

        with self.name_scope():
            self.proj_kernel_args = kernel_output.get_args_proj(
                self.float_type
            )
            self.num_hyperparams = kernel_output.get_num_args()
            self.embedding = mx.gluon.nn.Embedding(
                input_dim=self.cardinality,
                # The noise sigma is one extra parameter on top of the kernel
                # hyper-parameters, hence the "+ 1".
                output_dim=self.num_hyperparams + 1,
                dtype=self.float_type,
            )

    # noinspection PyMethodOverriding,PyPep8Naming
    def get_gp_params(
        self,
        F,
        past_target: Tensor,
        past_time_feat: Tensor,
        feat_static_cat: Tensor,
    ) -> Tuple:
        """
        Return the GP hyper-parameters for the model.

        Parameters
        ----------
        F
            A module that can either refer to the Symbol API or the NDArray
            API in MXNet.
        past_target
            Training time series values of shape (batch_size, context_length).
        past_time_feat
            Training features of shape (batch_size, context_length, num_features).
        feat_static_cat
            Time series indices of shape (batch_size, 1).

        Returns
        -------
        Tuple
            A pair ``(kernel_args, sigma)``:

            kernel_args
                Tuple of kernel hyper-parameters of length num_hyperparams.
                Each is a Tensor of shape (batch_size, 1, 1).
            sigma
                Model noise sigma, a Tensor of shape (batch_size, 1, 1).

            All values are clipped to the range [1e-5, 1e8].
        """
        output = self.embedding(
            feat_static_cat.squeeze()
        )  # Shape (batch_size, num_hyperparams + 1)
        kernel_args = self.proj_kernel_args(output)
        # The noise sigma is the last embedding column; softplus maps it to a
        # positive value.
        sigma = softplus(
            F,
            output.slice_axis(
                axis=1,
                begin=self.num_hyperparams,
                end=self.num_hyperparams + 1,
            ),
        )
        if self.params_scaling:
            scalings = self.kernel_output.gp_params_scaling(
                F, past_target, past_time_feat
            )
            # The first num_hyperparams scalings apply to the kernel arguments,
            # the next one to sigma.
            sigma = F.broadcast_mul(sigma, scalings[self.num_hyperparams])
            kernel_args = [
                F.broadcast_mul(kernel_arg, scaling)
                for kernel_arg, scaling in zip(
                    kernel_args, scalings[: self.num_hyperparams]
                )
            ]
        min_value = 1e-5
        max_value = 1e8
        # Materialize as a tuple rather than a one-shot generator. This matches
        # the documented return type and lets callers iterate it more than once.
        kernel_args = tuple(
            kernel_arg.clip(min_value, max_value).expand_dims(axis=2)
            for kernel_arg in kernel_args
        )
        sigma = sigma.clip(min_value, max_value).expand_dims(axis=2)
        return kernel_args, sigma


class GaussianProcessTrainingNetwork(GaussianProcessNetworkBase):
    """
    GP block used during training.

    It builds a GaussianProcess from the hyper-parameters produced by
    ``get_gp_params`` and returns that process's ``log_prob`` evaluated on
    the context window.
    """

    # noinspection PyMethodOverriding,PyPep8Naming
    @validated()
    def __init__(self, *args, **kwargs) -> None:
        """
        Parameters
        ----------
        *args
            Positional arguments forwarded unchanged to the base network.
        **kwargs
            Keyword arguments forwarded unchanged to the base network.
        """
        super().__init__(*args, **kwargs)

    # noinspection PyMethodOverriding,PyPep8Naming
    def hybrid_forward(
        self,
        F,
        past_target: Tensor,
        past_time_feat: Tensor,
        feat_static_cat: Tensor,
    ) -> Tensor:
        """
        Compute the GP training loss for a batch.

        Parameters
        ----------
        F
            A module that can either refer to the Symbol API or the NDArray
            API in MXNet.
        past_target
            Training time series values of shape (batch_size, context_length).
        past_time_feat
            Training features of shape (batch_size, context_length, num_features).
        feat_static_cat
            Time series indices of shape (batch_size, 1).

        Returns
        -------
        Tensor
            GP loss of shape (batch_size, 1), as returned by
            ``GaussianProcess.log_prob``.
        """
        hyperparams, noise_sigma = self.get_gp_params(
            F, past_target, past_time_feat, feat_static_cat
        )
        process = GaussianProcess(
            sigma=noise_sigma,
            kernel=self.kernel_output.kernel(hyperparams),
            context_length=self.context_length,
            ctx=self.ctx,
            float_type=self.float_type,
            max_iter_jitter=self.max_iter_jitter,
            jitter_method=self.jitter_method,
        )
        return process.log_prob(past_time_feat, past_target)


class GaussianProcessPredictionNetwork(GaussianProcessNetworkBase):
    """
    GP block used at prediction time.

    It conditions a GaussianProcess on the context window and draws
    ``num_parallel_samples`` forecast sample paths through
    ``GaussianProcess.exact_inference``.
    """

    @validated()
    def __init__(
        self, num_parallel_samples: int, sample_noise: bool, *args, **kwargs
    ) -> None:
        r"""
        Parameters
        ----------
        num_parallel_samples
            Number of sample paths drawn per time series.
        sample_noise
            Whether to add :math:`\sigma^2I` to the predictive covariance matrix.
        *args
            Positional arguments forwarded unchanged to the base network.
        **kwargs
            Keyword arguments forwarded unchanged to the base network.
        """
        super().__init__(*args, **kwargs)
        self.sample_noise = sample_noise
        self.num_parallel_samples = num_parallel_samples

    # noinspection PyMethodOverriding,PyPep8Naming
    def hybrid_forward(
        self,
        F,
        past_target: Tensor,
        past_time_feat: Tensor,
        future_time_feat: Tensor,
        feat_static_cat: Tensor,
    ) -> Tensor:
        """
        Draw forecast samples for a batch.

        Parameters
        ----------
        F
            A module that can either refer to the Symbol API or the NDArray
            API in MXNet.
        past_target
            Training time series values of shape (batch_size, context_length).
        past_time_feat
            Training features of shape (batch_size, context_length, num_features).
        future_time_feat
            Test features of shape (batch_size, prediction_length, num_features).
        feat_static_cat
            Time series indices of shape (batch_size, 1).

        Returns
        -------
        Tensor
            GP samples of shape (batch_size, num_samples, prediction_length).
        """
        hyperparams, noise_sigma = self.get_gp_params(
            F, past_target, past_time_feat, feat_static_cat
        )
        kernel = self.kernel_output.kernel(hyperparams)
        process = GaussianProcess(
            sigma=noise_sigma,
            kernel=kernel,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            num_samples=self.num_parallel_samples,
            ctx=self.ctx,
            float_type=self.float_type,
            max_iter_jitter=self.max_iter_jitter,
            jitter_method=self.jitter_method,
            sample_noise=self.sample_noise,
        )
        # Only the samples are needed here; the other two outputs are ignored.
        # Samples come back as (batch_size, prediction_length, num_samples).
        draws, _, _ = process.exact_inference(
            past_time_feat, past_target, future_time_feat
        )
        # Reorder to (batch_size, num_samples, prediction_length).
        return draws.swapaxes(1, 2)
